DALL·E Mini#
This is a simple way of creating DALL·E Mini artworks for generative artists.
Note
Install the ekorpkit package first.
Set the logging level to WARNING if you don't want to see verbose logging.
If you run this notebook in Colab, set the hardware accelerator to GPU.
Check your jaxlib version and install the matching CUDA build of jax. For example:
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
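The extra spec in the pip command above encodes the CUDA and cuDNN versions. A small sketch of how that spec string is assembled, assuming the `cudaXX_cudnnYY` naming used by the jax-releases index (verify against your installed CUDA/cuDNN versions before installing):

```python
def jax_cuda_spec(cuda_major: str, cudnn: str) -> str:
    # Build the pip extra spec for a CUDA build of jax,
    # e.g. cuda_major="11", cudnn="82" for CUDA 11.x + cuDNN 8.2.
    return f"jax[cuda{cuda_major}_cudnn{cudnn}]"

print(jax_cuda_spec("11", "82"))  # jax[cuda11_cudnn82]
```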
!pip install ekorpkit[art]
exit()
from ekorpkit import eKonf
eKonf.setLogger("INFO")
eKonf.set_cuda(device="4,5")
print("version:", eKonf.__version__)
is_notebook = eKonf.is_notebook()
is_colab = eKonf.is_colab()
print("is notebook?", is_notebook)
print("is colab?", is_colab)
if is_colab:
    eKonf.mount_google_drive(
        workspace="MyDrive/colab_workspace", project="disco-imagen"
    )
print("environment variables:")
eKonf.print(eKonf.env().dict())
INFO:ekorpkit.base:Setting cuda device to ['A100-SXM4-40GB', 'A100-SXM4-40GB']
INFO:ekorpkit.base:Google Colab not detected.
version: 0.1.36+4.gafa5b2d.dirty
is notebook? True
is colab? False
environment variables:
{'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
'CUDA_VISIBLE_DEVICES': '4, 5',
'EKORPKIT_CONFIG_DIR': '/workspace/projects/ekorpkit-book/config',
'EKORPKIT_DATA_DIR': None,
'EKORPKIT_LOG_LEVEL': 'INFO',
'EKORPKIT_PROJECT': 'ekorpkit-book',
'EKORPKIT_WORKSPACE_ROOT': '/workspace',
'KMP_DUPLICATE_LIB_OK': 'TRUE',
'NUM_WORKERS': 230}
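The `CUDA_DEVICE_ORDER` and `CUDA_VISIBLE_DEVICES` entries above show how GPU selection works. A sketch of what `eKonf.set_cuda(device="4,5")` appears to do under the hood (this is illustrative, not the actual implementation):

```python
import os

# Order devices by PCI bus id so indices are stable across tools.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Expose only physical GPUs 4 and 5; frameworks (JAX, PyTorch, ...)
# then see them as local devices 0 and 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
```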
cfg = eKonf.compose("model/dalle_mini")
dalle = eKonf.instantiate(cfg)
INFO:ekorpkit.base:Loaded .env from /workspace/projects/ekorpkit-book/config/.env
INFO:ekorpkit.base:setting environment variable CACHED_PATH_CACHE_ROOT to /workspace/.cache/cached_path
INFO:ekorpkit.base:setting environment variable KMP_DUPLICATE_LIB_OK to TRUE
INFO:ekorpkit.base:Google Colab not detected.
INFO:ekorpkit.models.art.base:> downloading models...
INFO:ekorpkit.models.art.base:> loading modules...
INFO:ekorpkit.utils.lib:dalle_mini not imported, loading from /workspace/projects/ekorpkit-book/disco-imagen/libs/dalle-mini/src/dalle_mini as dalle_mini
INFO:ekorpkit.utils.lib:vqgan_jax.modeling_flax_vqgan not imported, loading from /workspace/projects/ekorpkit-book/disco-imagen/libs/vqgan-jax as vqgan_jax.modeling_flax_vqgan
INFO:ekorpkit.models.art.base:> loading models...
INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
INFO:absl:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
INFO:absl:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO:ekorpkit.models.art.mini:Available devices: 6
INFO:ekorpkit.models.art.mini:Using 6 devices
INFO:ekorpkit.models.art.mini:Devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0), GpuDevice(id=4, process_index=0), GpuDevice(id=5, process_index=0)]
wandb: Downloading large artifact mega-1-fp16:latest, 4938.53MB. 7 files... Done. 0:0:11.7
Some of the weights of DalleBart were initialized in float16 precision from the model checkpoint at /tmp/tmpaijz02pf:
[('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'embed_tokens', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_0', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'FlaxBartAttention_1', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_0', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_1', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'Dense_2', 'kernel'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'LayerNorm_0', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'GLU_0', 'LayerNorm_1', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_0', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_1', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_1', 'scale'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_2', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_3', 'bias'), ('model', 'decoder', 'layers', 'FlaxBartDecoderLayers', 'LayerNorm_3', 
'scale'), ('model', 'encoder', 'embed_positions', 'embedding'), ('model', 'encoder', 'embed_tokens', 'embedding'), ('model', 'encoder', 'final_ln', 'bias'), ('model', 'encoder', 'layernorm_embedding', 'bias'), ('model', 'encoder', 'layernorm_embedding', 'scale'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'k_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'out_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'q_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'FlaxBartAttention_0', 'v_proj', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_0', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_1', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'Dense_2', 'kernel'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'LayerNorm_0', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'GLU_0', 'LayerNorm_1', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_0', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_1', 'bias'), ('model', 'encoder', 'layers', 'FlaxBartEncoderLayers', 'LayerNorm_1', 'scale')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
wandb: Downloading large artifact mega-1-fp16:latest, 4938.53MB. 7 files... Done. 0:0:14.3
# text_prompts = 'At a special meeting, dovish central bankers are poised to cut the target rates. Trending on artstation, matte style'
# text_prompts = "Dovish members of the Federal Reserve Board are to cut the target interest rates. matte style, artstation"
# batch_name = "dovish"
# text_prompts = 'At a special meeting, hawkish central bankers are poised to raise the target rates. Trending on artstation, matte style'
# batch_name = "hawkish"
# text_prompts = "Mt. Halla's beautiful flowers, artstation, matte"
# batch_name = "halla"
text_prompts = "Brave new world"
batch_name = "newworld"
dalle.imagine(
    text_prompts,
    batch_name=batch_name,
    n_samples=6,
    show_collage=True,
)
INFO:ekorpkit.models.art.mini: >> elapsed time to diffuse: 0:00:34.757086
INFO:ekorpkit.models.art.base:Merging config with args: {}
INFO:ekorpkit.models.art.base:Prompt: Brave new world
6 samples generated to /workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/newworld
text prompts: ['Brave new world']
sample image paths:
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/newworld/newworld(0)_0000.png
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/newworld/newworld(0)_0001.png
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/newworld/newworld(0)_0002.png
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/newworld/newworld(0)_0003.png
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/newworld/newworld(0)_0004.png
/workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/newworld/newworld(0)_0005.png
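The filenames in the output above suggest a `{batch_name}({batch_num})_{index:04d}.png` naming scheme. A small helper to reconstruct those paths (the layout is inferred from the log, not a documented API):

```python
def sample_path(output_dir: str, batch_name: str, batch_num: int, index: int) -> str:
    # e.g. outputs/dalle-mini/newworld/newworld(0)_0005.png
    return f"{output_dir}/{batch_name}/{batch_name}({batch_num})_{index:04d}.png"

print(sample_path("outputs/dalle-mini", "newworld", 0, 5))
# outputs/dalle-mini/newworld/newworld(0)_0005.png
```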
collage generated sample images#
dalle.collage(
    batch_name=batch_name,
    batch_num=5,
    ncols=3,
    num_images=6,
    show_filename=True,
    fontcolor="white",
)
INFO:ekorpkit.models.dalle.base:Loading config from /workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/halla/halla(5)_settings.yaml
INFO:ekorpkit.models.dalle.base:Merging config with diffuse defaults
INFO:ekorpkit.models.dalle.base:Prompt: Mt. Halla's beautiful flowers, artstation, matte
INFO:ekorpkit.io.file:Processing [6] files from ['halla(5)_*.png']
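With `num_images=6` and `ncols=3`, the collage is laid out on a 2×3 grid. A sketch of the grid-shape calculation any collage helper has to make (illustrative, not the ekorpkit internals):

```python
import math

def collage_shape(num_images: int, ncols: int) -> tuple:
    # Rows needed to fit num_images at ncols per row; the last row
    # may be partially filled.
    nrows = math.ceil(num_images / ncols)
    return nrows, ncols

print(collage_shape(6, 3))  # (2, 3)
```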
show config#
dalle.show_config(batch_name=batch_name, batch_num=0)
INFO:ekorpkit.models.dalle.base:Loading config from /workspace/projects/ekorpkit-book/disco-imagen/outputs/dalle-mini/halla/halla(0)_settings.yaml
INFO:ekorpkit.models.dalle.base:Merging config with diffuse defaults
{'batch_name': 'halla',
'batch_num': 8,
'cond_scale': 10.0,
'gen_top_k': None,
'gen_top_p': None,
'n_samples': 6,
'num_samples': 6,
'resume_run': False,
'run_to_resume': 'latest',
'seed': 2301282676,
'set_seed': 'random_seed',
'show_collage': True,
'temperature': None,
'text_prompts': ["Mt. Halla's beautiful flowers, photorealistic"]}
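The "Merging config with diffuse defaults" log lines suggest that per-batch settings like the dict above are overlaid on a set of defaults. A minimal sketch of such a merge (the actual ekorpkit merge is presumably recursive over nested configs; this shallow version only illustrates the precedence):

```python
def merge_settings(defaults: dict, overrides: dict) -> dict:
    # Keys from overrides win over the defaults; untouched defaults pass through.
    merged = dict(defaults)
    merged.update(overrides)
    return merged

defaults = {"n_samples": 1, "cond_scale": 10.0, "show_collage": True}
overrides = {"n_samples": 6, "batch_name": "halla"}
print(merge_settings(defaults, overrides))
```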